import pickle

import numpy as np
import pytest

from tokenizers import AddedToken, Encoding, Tokenizer
from tokenizers.implementations import BertWordPieceTokenizer
from tokenizers.models import BPE, Model, WordPiece, Unigram
from tokenizers.normalizers import Lowercase
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.processors import BertProcessing, RobertaProcessing

from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, roberta_files


class TestAddedToken:
    def test_instantiate_with_content_only(self):
        added_token = AddedToken("<mask>")
        added_token.content = "<MASK>"
        assert added_token.content == "<MASK>"
        assert type(added_token) == AddedToken
        added_token.content = added_token.content.lower()

        assert added_token.special == False
        added_token.special = True
        assert added_token.special == True
        added_token.special = False
        assert str(added_token) == "<mask>"
        assert (
            repr(added_token)
            == 'AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False)'
        )
        assert added_token.rstrip == False
        assert added_token.lstrip == False
        assert added_token.single_word == False
        assert added_token.normalized == True
        assert isinstance(pickle.loads(pickle.dumps(added_token)), AddedToken)

    def test_can_set_rstrip(self):
        added_token = AddedToken("<mask>", rstrip=True)
        assert added_token.rstrip == True
        assert added_token.lstrip == False
        assert added_token.single_word == False
        assert added_token.normalized == True

    def test_can_set_lstrip(self):
        added_token = AddedToken("<mask>", lstrip=True)
        assert added_token.rstrip == False
        assert added_token.lstrip == True
        assert added_token.single_word == False
        assert added_token.normalized == True

    def test_can_set_single_word(self):
        added_token = AddedToken("<mask>", single_word=True)
        assert added_token.rstrip == False
        assert added_token.lstrip == False
        assert added_token.single_word == True
        assert added_token.normalized == True

    def test_can_set_normalized(self):
        added_token = AddedToken("<mask>", normalized=False)
        assert added_token.rstrip == False
        assert added_token.lstrip == False
        assert added_token.single_word == False
        assert added_token.normalized == False


class TestTokenizer:
    def test_has_expected_type_and_methods(self):
        tokenizer = Tokenizer(BPE())
        assert type(tokenizer) == Tokenizer
        assert callable(tokenizer.num_special_tokens_to_add)
        assert callable(tokenizer.get_vocab)
        assert callable(tokenizer.get_vocab_size)
        assert callable(tokenizer.enable_truncation)
        assert callable(tokenizer.no_truncation)
        assert callable(tokenizer.enable_padding)
        assert callable(tokenizer.no_padding)
        assert callable(tokenizer.encode)
        assert callable(tokenizer.encode_batch)
        assert callable(tokenizer.decode)
        assert callable(tokenizer.decode_batch)
        assert callable(tokenizer.token_to_id)
        assert callable(tokenizer.id_to_token)
        assert callable(tokenizer.add_tokens)
        assert callable(tokenizer.add_special_tokens)
        assert callable(tokenizer.train)
        assert callable(tokenizer.post_process)
        assert isinstance(tokenizer.model, Model)
        assert tokenizer.normalizer is None
        assert tokenizer.pre_tokenizer is None
        assert tokenizer.post_processor is None
        assert tokenizer.decoder is None
        assert isinstance(pickle.loads(pickle.dumps(Tokenizer(BPE()))), Tokenizer)

    def test_add_tokens(self):
        tokenizer = Tokenizer(BPE())
        added = tokenizer.add_tokens(["my", "name", "is", "john"])
        assert added == 4

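        # Mixing `AddedToken` instances: only "the" and "quick" count as new entries;
        # the default-constructed AddedToken (empty content) is presumably not added.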
        tokens = [AddedToken("the"), AddedToken("quick", normalized=False), AddedToken()]
        assert tokens[0].normalized == True
        added = tokenizer.add_tokens(tokens)
        assert added == 2
        assert tokens[0].normalized == True
        assert tokens[1].normalized == False

    def test_add_special_tokens(self):
        tokenizer = Tokenizer(BPE())

        # Can add special tokens as `str`
        added = tokenizer.add_special_tokens(["my", "name", "is", "john"])
        assert added == 4

        # Can add special tokens as `AddedToken`
        tokens = [AddedToken("the"), AddedToken("quick", normalized=True), AddedToken()]
        assert tokens[0].normalized == True
        added = tokenizer.add_special_tokens(tokens)
        assert added == 2
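        # Special tokens are not normalized by default, so "the" (created with the
        # default) ends up with normalized=False, while the explicit normalized=True
        # on "quick" is preserved.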
        assert tokens[0].normalized == False
        assert tokens[1].normalized == True

    def test_encode(self):
        tokenizer = Tokenizer(BPE())
        tokenizer.add_tokens(["my", "name", "is", "john", "pair"])

        # Can encode single sequence
        output = tokenizer.encode("my name is john")
        assert output.tokens == ["my", "name", "is", "john"]
        assert type(output.ids) == list
        assert type(output.type_ids) == list
        assert type(output.offsets) == list
        with pytest.warns(DeprecationWarning):
            assert type(output.words) == list
        assert type(output.word_ids) == list
        assert type(output.special_tokens_mask) == list
        assert type(output.attention_mask) == list
        assert type(output.overflowing) == list

        # Can encode a pair of sequences
        output = tokenizer.encode("my name is john", "pair")
        assert output.tokens == ["my", "name", "is", "john", "pair"]
        assert isinstance(pickle.loads(pickle.dumps(output)), Encoding)

        # Can encode a single pre-tokenized sequence
        output = tokenizer.encode(["my", "name", "is", "john"], is_pretokenized=True)
        assert output.tokens == ["my", "name", "is", "john"]

        # Can encode a batch with both a single sequence and a pair of sequences
        output = tokenizer.encode_batch(["my name is john", ("my name is john", "pair")])
        assert len(output) == 2

    def test_encode_formats(self, bert_files):
        with pytest.deprecated_call():
            tokenizer = BertWordPieceTokenizer(bert_files["vocab"])

        # Encode
        output = tokenizer.encode("my name is john")
        assert output.tokens == ["[CLS]", "my", "name", "is", "john", "[SEP]"]
        output = tokenizer.encode("my name is john", "pair")
        assert output.tokens == ["[CLS]", "my", "name", "is", "john", "[SEP]", "pair", "[SEP]"]
        output = tokenizer.encode(["my", "name", "is", "john"], is_pretokenized=True)
        assert output.tokens == ["[CLS]", "my", "name", "is", "john", "[SEP]"]
        output = tokenizer.encode(["my", "name", "is", "john"], ["pair"], is_pretokenized=True)
        assert output.tokens == ["[CLS]", "my", "name", "is", "john", "[SEP]", "pair", "[SEP]"]

        # Encode batch
        result_single = [
            ["[CLS]", "my", "name", "is", "john", "[SEP]"],
            ["[CLS]", "my", "name", "is", "georges", "[SEP]"],
        ]
        result_pair = [
            ["[CLS]", "my", "name", "is", "john", "[SEP]", "pair", "[SEP]"],
            ["[CLS]", "my", "name", "is", "georges", "[SEP]", "pair", "[SEP]"],
        ]

        def format(encodings):
            return [e.tokens for e in encodings]

        def test_single(input, is_pretokenized=False):
            output = tokenizer.encode_batch(input, is_pretokenized=is_pretokenized)
            assert format(output) == result_single

        def test_pair(input, is_pretokenized=False):
            output = tokenizer.encode_batch(input, is_pretokenized=is_pretokenized)
            assert format(output) == result_pair

        # Classic inputs

        # Lists
        test_single(["My name is John", "My name is Georges"])
        test_pair([("my name is john", "pair"), ("my name is georges", "pair")])
        test_pair([["my name is john", "pair"], ["my name is georges", "pair"]])

        # Tuples
        test_single(("My name is John", "My name is Georges"))
        test_pair((("My name is John", "pair"), ("My name is Georges", "pair")))

        # Numpy
        test_single(np.array(["My name is John", "My name is Georges"]))
        test_pair(np.array([("My name is John", "pair"), ("My name is Georges", "pair")]))
        test_pair(np.array([["My name is John", "pair"], ["My name is Georges", "pair"]]))

        # PreTokenized inputs

        # Lists
        test_single([["My", "name", "is", "John"], ["My", "name", "is", "Georges"]], True)
        test_pair(
            [
                (["My", "name", "is", "John"], ["pair"]),
                (["My", "name", "is", "Georges"], ["pair"]),
            ],
            True,
        )
        test_pair(
            [
                [["My", "name", "is", "John"], ["pair"]],
                [["My", "name", "is", "Georges"], ["pair"]],
            ],
            True,
        )

        # Tuples
        test_single((("My", "name", "is", "John"), ("My", "name", "is", "Georges")), True)
        test_pair(
            (
                (("My", "name", "is", "John"), ("pair",)),
                (("My", "name", "is", "Georges"), ("pair",)),
            ),
            True,
        )
        test_pair(
            (
                (["My", "name", "is", "John"], ["pair"]),
                (["My", "name", "is", "Georges"], ["pair"]),
            ),
            True,
        )

        # Numpy
        test_single(
            np.array([["My", "name", "is", "John"], ["My", "name", "is", "Georges"]]),
            True,
        )
        test_single(
            np.array((("My", "name", "is", "John"), ("My", "name", "is", "Georges"))),
            True,
        )
        test_pair(
            np.array(
                [
                    [["My", "name", "is", "John"], ["pair"]],
                    [["My", "name", "is", "Georges"], ["pair"]],
                ],
                dtype=object,
            ),
            True,
        )
        test_pair(
            np.array(
                (
                    (("My", "name", "is", "John"), ("pair",)),
                    (("My", "name", "is", "Georges"), ("pair",)),
                ),
                dtype=object,
            ),
            True,
        )

        # Malformed inputs
        with pytest.raises(TypeError, match="TextInputSequence must be str"):
            tokenizer.encode([["my", "name"]])
        with pytest.raises(TypeError, match="TextInputSequence must be str"):
            tokenizer.encode("My name is john", [["pair"]])
        with pytest.raises(TypeError, match="TextInputSequence must be str"):
            tokenizer.encode("my name is john", ["pair"])

        with pytest.raises(TypeError, match="InputSequence must be Union[List[str]"):
            tokenizer.encode("My name is john", is_pretokenized=True)
        with pytest.raises(TypeError, match="InputSequence must be Union[List[str]"):
            tokenizer.encode("My name is john", ["pair"], is_pretokenized=True)
        with pytest.raises(TypeError, match="InputSequence must be Union[List[str]"):
            tokenizer.encode(["My", "name", "is", "John"], "pair", is_pretokenized=True)

    def test_encode_add_special_tokens(self, roberta_files):
        with pytest.deprecated_call():
            tokenizer = Tokenizer(BPE(roberta_files["vocab"], roberta_files["merges"]))
        tokenizer.add_special_tokens(["<s>", "</s>"])

        tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)
        tokenizer.post_processor = RobertaProcessing(
            ("</s>", tokenizer.token_to_id("</s>")),
            ("<s>", tokenizer.token_to_id("<s>")),
        )

        # Can encode with special tokens
        output_with_specials = tokenizer.encode("My name is John", add_special_tokens=True)
        assert output_with_specials.tokens == ["<s>", "ĠMy", "Ġname", "Ġis", "ĠJohn", "</s>"]

        # Can encode without special tokens
        output_without_specials = tokenizer.encode("My name is John", add_special_tokens=False)
        assert output_without_specials.tokens == ["ĠMy", "Ġname", "Ġis", "ĠJohn"]

    def test_truncation(self):
        tokenizer = Tokenizer(BPE())
        tokenizer.add_tokens(["my", "name", "is", "john", "pair"])
        tokenizer.enable_truncation(2)

        # Can truncate single sequences
        output = tokenizer.encode("my name is john")
        assert output.tokens == ["my", "name"]

        # Can truncate pair sequences as well
        output = tokenizer.encode("my name is john", "pair")
        assert output.tokens == ["my", "pair"]

        # Can get the params and give them to enable_truncation
        trunc = tokenizer.truncation
        tokenizer.enable_truncation(**trunc)

        # Left truncation direction
        tokenizer.enable_truncation(2, direction="left")
        output = tokenizer.encode("my name is john")
        assert output.tokens == ["is", "john"]

        output = tokenizer.encode("my name is john", "pair")
        assert output.tokens == ["john", "pair"]

    def test_padding(self):
        tokenizer = Tokenizer(BPE())
        tokenizer.add_tokens(["my", "name", "is", "john", "pair"])

        # By default it does nothing when encoding a single sequence
        tokenizer.enable_padding()
        output = tokenizer.encode("my name")
        assert output.tokens == ["my", "name"]

        # Can pad to the longest in a batch
        output = tokenizer.encode_batch(["my name", "my name is john"])
        assert all([len(encoding) == 4 for encoding in output])

        # Can pad to the specified length otherwise
        tokenizer.enable_padding(length=4)
        output = tokenizer.encode("my name")
        assert output.tokens == ["my", "name", "[PAD]", "[PAD]"]
        output = tokenizer.encode("my name", "pair")
        assert output.tokens == ["my", "name", "pair", "[PAD]"]

        # Can get the params and give them to enable_padding
        padding = tokenizer.padding
        tokenizer.enable_padding(**padding)

    def test_decode(self):
        tokenizer = Tokenizer(BPE())
        tokenizer.add_tokens(["my", "name", "is", "john", "pair"])

        # Can decode single sequences
        output = tokenizer.decode([0, 1, 2, 3])
        assert output == "my name is john"

        # Can decode batch
        output = tokenizer.decode_batch([[0, 1, 2, 3], [4]])
        assert output == ["my name is john", "pair"]

    def test_get_vocab(self):
        tokenizer = Tokenizer(BPE())
        tokenizer.add_tokens(["my", "name", "is", "john", "pair"])

        # Can retrieve vocab with added tokens
        vocab = tokenizer.get_vocab(with_added_tokens=True)
        assert vocab == {"is": 2, "john": 3, "my": 0, "name": 1, "pair": 4}

        # Can retrieve vocab without added tokens
        vocab = tokenizer.get_vocab(with_added_tokens=False)
        assert vocab == {}

        # Can retrieve added token decoder
        vocab = tokenizer.get_added_tokens_decoder()
        assert vocab == {
            0: AddedToken("my", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
            1: AddedToken("name", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
            2: AddedToken("is", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
            3: AddedToken("john", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
            4: AddedToken("pair", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
        }

    def test_get_vocab_size(self):
        tokenizer = Tokenizer(BPE())
        tokenizer.add_tokens(["my", "name", "is", "john", "pair"])

        # Can retrieve vocab's size with added tokens
        size = tokenizer.get_vocab_size(with_added_tokens=True)
        assert size == 5

        # Can retrieve vocab's size without added tokens
        size = tokenizer.get_vocab_size(with_added_tokens=False)
        assert size == 0

    def test_post_process(self):
        tokenizer = Tokenizer(BPE())
        tokenizer.add_tokens(["my", "name", "is", "john", "pair"])
        tokenizer.enable_truncation(2)
        tokenizer.enable_padding(length=4)

        encoding = tokenizer.encode("my name is john")
        pair_encoding = tokenizer.encode("pair")

        # Can post process a single encoding
        output = tokenizer.post_process(encoding)
        assert output.tokens == ["my", "name", "[PAD]", "[PAD]"]

        # Can post process a pair of encodings
        output = tokenizer.post_process(encoding, pair_encoding)
        assert output.tokens == ["my", "pair", "[PAD]", "[PAD]"]

    def test_multiprocessing_with_parallelism(self):
        tokenizer = Tokenizer(BPE())
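        # Encoding should keep working from forked processes, both with and without parallelism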
        multiprocessing_with_parallelism(tokenizer, False)
        multiprocessing_with_parallelism(tokenizer, True)

    def test_from_pretrained(self):
        tokenizer = Tokenizer.from_pretrained("bert-base-cased")
        output = tokenizer.encode("Hey there dear friend!", add_special_tokens=False)
        assert output.tokens == ["Hey", "there", "dear", "friend", "!"]

    def test_from_pretrained_revision(self):
        tokenizer = Tokenizer.from_pretrained("anthony/tokenizers-test")
        output = tokenizer.encode("Hey there dear friend!", add_special_tokens=False)
        assert output.tokens == ["hey", "there", "dear", "friend", "!"]

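        # The "gpt-2" revision uses a byte-level tokenizer, hence the "Ġ"-prefixed tokens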
        tokenizer = Tokenizer.from_pretrained("anthony/tokenizers-test", revision="gpt-2")
        output = tokenizer.encode("Hey there dear friend!", add_special_tokens=False)
        assert output.tokens == ["Hey", "Ġthere", "Ġdear", "Ġfriend", "!"]

    def test_unigram_byte_fallback(self):
        vocab = [
            ("<unk>", 0.0),
            ("A", -0.01),
            ("sen", -0.02),
            ("te", -0.03),
            ("n", -0.04),
            ("ce", -0.05),
            ("<0xF0>", -0.06),
            ("<0x9F>", -0.06),
            ("<0xA4>", -0.06),
            ("<0x97>", -0.06),
            (" ", -0.4),
        ]
        tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=False))

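        # Without byte fallback, the emoji is not in the vocab and maps to <unk> (id 0)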
        output = tokenizer.encode("A sentence 🤗")
        assert output.ids == [1, 10, 2, 3, 4, 5, 10, 0]
        assert output.tokens == ["A", " ", "sen", "te", "n", "ce", " ", "🤗"]

        tokenizer = Tokenizer(Unigram(vocab, 0, byte_fallback=True))

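        # With byte fallback, the emoji is decomposed into its UTF-8 bytes instead of <unk>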
        output = tokenizer.encode("A sentence 🤗")
        assert output.ids == [1, 10, 2, 3, 4, 5, 10, 6, 7, 8, 9]
        assert output.tokens == ["A", " ", "sen", "te", "n", "ce", " ", "<0xF0>", "<0x9F>", "<0xA4>", "<0x97>"]

    def test_encode_special_tokens(self):
        tokenizer = Tokenizer.from_pretrained("t5-base")
        tokenizer.add_tokens(["<eot>"])
        tokenizer.add_special_tokens(["<end_of_text>"])
        output = tokenizer.encode("Hey there<end_of_text> dear<eot>friend!", add_special_tokens=False)
        assert output.tokens == ["▁Hey", "▁there", "<end_of_text>", "▁dear", "<eot>", "▁friend", "!"]

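        # Enabling `encode_special_tokens` makes special tokens get split and encoded
        # like regular text, while non-special added tokens such as "<eot>" are still
        # matched as a whole.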
        tokenizer.encode_special_tokens = True
        assert tokenizer.encode_special_tokens == True

        output = tokenizer.encode("Hey there<end_of_text> dear<eot>friend!", add_special_tokens=False)
        assert output.tokens == [
            "▁Hey",
            "▁there",
            "<",
            "end",
            "_",
            "of",
            "_",
            "text",
            ">",
            "▁dear",
            "<eot>",
            "▁friend",
            "!",
        ]

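        # Regular added tokens are still matched, even inside the split-up special token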
        tokenizer.add_tokens(["of_text>"])
        output = tokenizer.encode("Hey there<end_of_text> dear<eot>friend!", add_special_tokens=False)
        assert output.tokens == ["▁Hey", "▁there", "<", "end", "_", "of_text>", "▁dear", "<eot>", "▁friend", "!"]
